import zodToJsonSchema from 'zod-to-json-schema';
import {
  LanguageModelV1,
  LanguageModelV1CallWarning,
} from '../../ai-model-specification/index';
import {
  AIStreamCallbacksAndOptions,
  createCallbacksTransformer,
  createStreamDataTransformer,
  readableFromAsyncIterable,
} from '../../streams';
import { CallSettings } from '../prompt/call-settings';
import { convertToLanguageModelPrompt } from '../prompt/convert-to-language-model-prompt';
import { getInputFormat } from '../prompt/get-input-format';
import { prepareCallSettings } from '../prompt/prepare-call-settings';
import { Prompt } from '../prompt/prompt';
import { ExperimentalTool } from '../tool';
import {
  AsyncIterableStream,
  createAsyncIterableStream,
} from '../util/async-iterable-stream';
import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';
import { runToolsTransformation } from './run-tools-transformation';
import { ToToolCall } from './tool-call';
import { ToToolResult } from './tool-result';

/**
 * Stream text generated by a language model.
 *
 * The model is called through `model.doStream` in `regular` mode. That call is
 * wrapped by the helper returned from
 * `retryWithExponentialBackoff({ maxRetries })`. When `tools` are given, each
 * entry is announced to the model as a `function` tool, and its zod
 * `parameters` are converted with `zodToJsonSchema`. The resulting stream goes
 * through `runToolsTransformation` and is returned inside a
 * `StreamTextResult`, together with the warnings reported by `doStream`.
 *
 * @param model - the language model whose `doStream` is invoked.
 * @param tools - optional map from tool name to tool definition.
 * @param system - forwarded to `convertToLanguageModelPrompt`.
 * @param prompt - forwarded to `getInputFormat` and `convertToLanguageModelPrompt`.
 * @param messages - forwarded to `getInputFormat` and `convertToLanguageModelPrompt`.
 * @param maxRetries - passed to `retryWithExponentialBackoff`.
 * @param abortSignal - forwarded unchanged to `model.doStream`.
 * @param settings - remaining call settings, normalized by `prepareCallSettings`.
 * @returns a `StreamTextResult` over the tool-transformed stream.
 */
export async function experimental_streamText<
  TOOLS extends Record<string, ExperimentalTool>,
>({
  model,
  tools,
  system,
  prompt,
  messages,
  maxRetries,
  abortSignal,
  ...settings
}: CallSettings &
  Prompt & {
    model: LanguageModelV1;
    tools?: TOOLS;
  }): Promise<StreamTextResult<TOOLS>> {
  const retry = retryWithExponentialBackoff({ maxRetries });

  const { stream, warnings } = await retry(() => {
    // Everything below runs inside the retried callback, so each attempt
    // rebuilds the tool schemas, call settings and prompt.
    const functionTools =
      tools != null
        ? Object.entries(tools).map(([toolName, toolDefinition]) => ({
            type: 'function' as const,
            name: toolName,
            description: toolDefinition.description,
            parameters: zodToJsonSchema(toolDefinition.parameters),
          }))
        : undefined;

    return model.doStream({
      mode: { type: 'regular', tools: functionTools },
      ...prepareCallSettings(settings),
      inputFormat: getInputFormat({ prompt, messages }),
      prompt: convertToLanguageModelPrompt({ system, prompt, messages }),
      abortSignal,
    });
  });

  const toolAwareStream = runToolsTransformation({
    tools,
    generatorStream: stream,
  });

  return new StreamTextResult({ stream: toolAwareStream, warnings });
}

/**
 * One part of the stream produced by `experimental_streamText`.
 * The union is discriminated by the literal `type` field.
 */
export type TextStreamPart<TOOLS extends Record<string, ExperimentalTool>> =
  /** A chunk of generated text. */
  | { type: 'text-delta'; textDelta: string }
  /** A tool call; the remaining fields come from `ToToolCall<TOOLS>`. */
  | ({ type: 'tool-call' } & ToToolCall<TOOLS>)
  /** A tool result; the remaining fields come from `ToToolResult<TOOLS>`. */
  | ({ type: 'tool-result' } & ToToolResult<TOOLS>)
  /** An error carried in-band on the stream; its value is untyped. */
  | { type: 'error'; error: unknown };

/**
 * Result of `experimental_streamText`. It exposes the model output as a
 * text-only stream, a full stream of `TextStreamPart`s, or an AI stream,
 * together with the warnings for the call.
 *
 * A `ReadableStream` can only be read by a single consumer. Every accessor
 * (`textStream`, `fullStream`, `toAIStream()`) therefore reads from its own
 * `tee()` branch of the underlying stream. This lets several views be
 * requested, or the same view twice, without a "stream is locked" failure.
 */
export class StreamTextResult<TOOLS extends Record<string, ExperimentalTool>> {
  /**
   * The stream that the next view branches from. It is replaced by a fresh
   * `tee()` branch each time a view is created (see `teeStream`).
   */
  private originalStream: ReadableStream<TextStreamPart<TOOLS>>;

  /**
   * Warnings for the call, as passed to the constructor.
   */
  readonly warnings: LanguageModelV1CallWarning[] | undefined;

  constructor({
    stream,
    warnings,
  }: {
    stream: ReadableStream<TextStreamPart<TOOLS>>;
    warnings: LanguageModelV1CallWarning[] | undefined;
  }) {
    this.originalStream = stream;
    this.warnings = warnings;
  }

  /**
   * Splits the underlying stream. One branch is returned to the caller and the
   * other is kept so that later views can still read the complete output.
   *
   * Trade-offs of always keeping a branch:
   * - chunks are buffered in the retained branch until it is read, so memory
   *   grows with the output if no further view is requested;
   * - cancelling a returned view does not cancel the underlying stream,
   *   because the retained branch is still attached to it.
   */
  private teeStream(): ReadableStream<TextStreamPart<TOOLS>> {
    const [branch, remainder] = this.originalStream.tee();
    this.originalStream = remainder;
    return branch;
  }

  /**
   * The generated text only. Non-empty `text-delta` parts are emitted as
   * strings, and empty deltas and tool parts are dropped. An `error` part's
   * `error` is rethrown from the transform.
   *
   * Each access returns a new, independent stream.
   */
  get textStream(): AsyncIterableStream<string> {
    return createAsyncIterableStream(this.teeStream(), {
      transform(chunk, controller) {
        if (chunk.type === 'text-delta') {
          // do not stream empty text deltas:
          if (chunk.textDelta.length > 0) {
            controller.enqueue(chunk.textDelta);
          }
        } else if (chunk.type === 'error') {
          throw chunk.error;
        }
      },
    });
  }

  /**
   * Every stream part (text deltas, tool calls, tool results, errors), except
   * that empty text deltas are dropped.
   *
   * Each access returns a new, independent stream.
   */
  get fullStream(): AsyncIterableStream<TextStreamPart<TOOLS>> {
    return createAsyncIterableStream(this.teeStream(), {
      transform(chunk, controller) {
        if (chunk.type === 'text-delta') {
          // do not stream empty text deltas:
          if (chunk.textDelta.length > 0) {
            controller.enqueue(chunk);
          }
        } else {
          controller.enqueue(chunk);
        }
      },
    });
  }

  /**
   * Converts the text stream into an AI stream. The text is piped through the
   * callbacks transformer and then the stream-data transformer
   * (`callbacks.experimental_streamData`). Tool calls are not included yet.
   */
  toAIStream(callbacks?: AIStreamCallbacksAndOptions) {
    // TODO add support for tool calls
    return readableFromAsyncIterable(this.textStream)
      .pipeThrough(createCallbacksTransformer(callbacks))
      .pipeThrough(
        createStreamDataTransformer(callbacks?.experimental_streamData),
      );
  }
}
